package com.jl.springbean.util;

import com.jl.JLReflect;
import com.jl.JLTuple;
import lombok.AllArgsConstructor;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.stereotype.Component;

import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * Spring bean utility for registering and removing beans at runtime.
 * To modify an existing bean, obtain the bean object and change its properties through its setters.
 */
@Component
@AllArgsConstructor
public class JLSpringBean {

    private final ConfigurableApplicationContext applicationContext;

    /**
     * Registers a bean of the given class under a name derived from the class name
     * (the part after the last {@code '.'}, with its first character lower-cased,
     * e.g. {@code com.x.UserService} -> {@code userService}).
     *
     * @param clazz bean class
     * @param args  property values to inject, keyed by property name; {@code null} means none
     * @param <T>   bean type
     * @return the bean instance obtained from the context after registration
     * @throws IllegalStateException if a bean with the derived name already exists
     */
    public <T> T registerBean(Class<T> clazz, Map<String, Object> args) {
        Objects.requireNonNull(clazz, "clazz");
        return registerBean(getName(clazz), clazz, args);
    }

    /**
     * Registers a bean definition for {@code clazz} under {@code name}, injecting {@code args}
     * as property values, and returns the instance the context creates for it.
     * If creating the bean fails, the definition is removed again so the name stays free.
     *
     * @param name  bean name
     * @param clazz bean class
     * @param args  property values to inject, keyed by property name; {@code null} means none
     * @param <T>   bean type
     * @return the bean instance obtained via {@code getBean(name, clazz)}
     * @throws IllegalStateException if a bean with this name already exists
     */
    public <T> T registerBean(String name, Class<T> clazz, Map<String, Object> args) {
        Objects.requireNonNull(clazz, "clazz");
        ensureAbsent(name);
        BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(clazz);
        if (args != null) {
            args.forEach(builder::addPropertyValue);
        }
        return register(name, builder, clazz);
    }

    /**
     * Registers a bean modelled on {@code obj} under a name derived from its runtime class name
     * (the part after the last {@code '.'}, with its first character lower-cased).
     *
     * @param obj template object whose non-null property values are copied into the definition
     * @param <T> bean type
     * @return a new bean instance created by the context (not {@code obj} itself)
     * @throws IllegalStateException if a bean with the derived name already exists
     */
    public <T> T registerBean(T obj) {
        Objects.requireNonNull(obj, "obj");
        return registerBean(getName(obj.getClass()), obj);
    }

    /**
     * Registers a bean of {@code obj}'s runtime class under {@code name}, copying obj's non-null
     * property values into the bean definition. The context then creates a new instance from that
     * definition; the returned bean is that instance, not {@code obj} itself.
     * If creating the bean fails, the definition is removed again so the name stays free.
     *
     * @param name bean name
     * @param obj  template object whose non-null property values are copied
     * @param <T>  bean type
     * @return the bean instance obtained via {@code getBean(name, obj.getClass())}
     * @throws IllegalStateException if a bean with this name already exists
     */
    public <T> T registerBean(String name, T obj) {
        Objects.requireNonNull(obj, "obj");
        ensureAbsent(name);
        @SuppressWarnings("unchecked") // the runtime class of a T instance is always a subtype of T
        Class<T> clazz = (Class<T>) obj.getClass();
        BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(clazz);
        // NOTE(review): getProperty is assumed to yield (propertyName, value, type) tuples - confirm in JLReflect.
        List<JLTuple.Tuple3<String, Object, Class<?>>> properties = JLReflect.PropertyReflect.getProperty(obj);
        for (JLTuple.Tuple3<String, Object, Class<?>> property : properties) {
            Object value = property.getV2();
            // null values are not injected; the new instance keeps whatever value it was created with
            if (value != null) {
                builder.addPropertyValue(property.getV1(), value);
            }
        }
        return register(name, builder, clazz);
    }

    /**
     * Removes the bean definition registered under {@code name}, if there is one.
     * Only definitions registered directly under this exact name in this context's bean factory are
     * removed; aliases and beans living in a parent context are left untouched (previously these made
     * {@code removeBeanDefinition} throw).
     *
     * @param name bean name
     */
    public void removeBean(String name) {
        BeanDefinitionRegistry registry = getRegistry();
        if (registry.containsBeanDefinition(name)) {
            registry.removeBeanDefinition(name);
        }
    }

    /**
     * Returns whether the context knows a bean with this name.
     * Delegates to {@code ApplicationContext#containsBean}, which also matches aliases and beans
     * defined in parent contexts.
     *
     * @param name bean name
     * @return {@code true} if {@code containsBean(name)} is true
     */
    public boolean existBean(String name) {
        return applicationContext.containsBean(name);
    }

    /** Throws if {@code name} is already taken according to {@link #existBean(String)}. */
    private void ensureAbsent(String name) {
        if (existBean(name)) {
            throw new IllegalStateException("bean重复：" + name);
        }
    }

    /**
     * Registers the built definition under {@code name} and fetches the resulting bean.
     * On failure the definition is rolled back so a retry with the same name is possible.
     */
    private <T> T register(String name, BeanDefinitionBuilder builder, Class<T> clazz) {
        BeanDefinitionRegistry registry = getRegistry();
        BeanDefinition definition = builder.getRawBeanDefinition();
        registry.registerBeanDefinition(name, definition);
        try {
            return applicationContext.getBean(name, clazz);
        } catch (RuntimeException e) {
            // Without this, a failed creation (e.g. an unknown property name) would leave the
            // definition registered and every later attempt would fail with "bean重复".
            try {
                registry.removeBeanDefinition(name);
            } catch (RuntimeException rollbackFailure) {
                e.addSuppressed(rollbackFailure);
            }
            throw e;
        }
    }

    /** The bean-definition registry behind the application context (its internal bean factory). */
    private BeanDefinitionRegistry getRegistry() {
        return (BeanDefinitionRegistry) applicationContext.getBeanFactory();
    }

    /**
     * Derives a bean name from a class: the part of {@link Class#getName()} after the last '.',
     * with the first character lower-cased (nested classes keep their '$', e.g. {@code outer$Inner}).
     */
    private String getName(Class<?> clazz) {
        String qualified = clazz.getName();
        String shortName = qualified.substring(qualified.lastIndexOf('.') + 1);
        // Locale.ROOT: default-locale lower-casing maps 'I' to a dotless 'ı' under a Turkish locale
        return shortName.substring(0, 1).toLowerCase(Locale.ROOT) + shortName.substring(1);
    }
}
